import json
import argparse
import random
from trainer import train

def _split_ints(text):
    """Parse a whitespace-separated string of integers, e.g. "1 2 3" -> [1, 2, 3]."""
    return [int(token) for token in text.split()]


def main(data_name, str_):
    """Build the settings dict for one experiment and hand it to ``train``.

    The command line is parsed with ``setup_parser(data_name, str_)``, the
    JSON file named by ``--config`` is loaded, and its entries are merged on
    top of the parsed options. A few entries are then normalised (see the
    inline comments) before ``train`` is called with the resulting dict.

    Args:
        data_name: Dataset name; forwarded to ``setup_parser``.
        str_: Experiment/config name; forwarded to ``setup_parser``. If it
            contains the substring ``'attn'``, ``attn_store`` is forced to 1.

    Raises:
        KeyError: If an expected key is missing from the merged settings.
        ValueError: If a list/float entry cannot be converted.
    """
    args = vars(setup_parser(data_name, str_).parse_args())
    # dict.update: entries from the JSON file replace same-named CLI options.
    args.update(load_json(args['config']))

    # 999 means "not given" for these two options; any other value is copied
    # into the key stored below.
    if args['gpu_used'] != 999:
        args['device'] = ['{}'.format(args['gpu_used'])]
    if args['prompt_num'] != 999:
        args['prompt_token_num'] = args['prompt_num']

    if 'attn' in str_:
        args['attn_store'] = 1

    # Whitespace-separated integer strings become lists of ints; an empty
    # string is left untouched.
    for key in ('Block_which', 'Part_which'):
        if args[key] != "":
            args[key] = _split_ints(args[key])

    # Loss weights are converted to float; an empty string is left untouched.
    for key in ('lamda_for_logit_loss', 'lamda_for_prompt',
                'lamda_for_featureformer', 'lamda_for_featurelower',
                'lamda_for_pool3'):
        if args[key] != '':
            args[key] = float(args[key])

    # Unlike the weights above, an empty lamda0 falls back to 0.
    args['lamda0'] = 0 if args['lamda0'] == '' else float(args['lamda0'])

    if args['train_list'] != "":
        args['train_list'] = _split_ints(args['train_list'])

    # No log name given: generate a temporary one with a random numeric suffix.
    if args['file_name'] == '':
        args['file_name'] = 'Temporary_Log' + '{}'.format(random.randint(10000, 20000))
    args['print'] = 1
    train(args)

def load_json(settings_path):
    """Read the JSON file at ``settings_path`` and return the decoded object.

    The file is opened as UTF-8 explicitly so that decoding does not depend
    on the platform's locale encoding.

    Args:
        settings_path: Path of the JSON file to read.

    Returns:
        The value produced by ``json.load`` for the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(settings_path, encoding='utf-8') as data_file:
        return json.load(data_file)


#===============================================================
def setup_parser(data_name, str_):
    """Create the command-line parser for one experiment.

    Args:
        data_name: Dataset name; used to build the default ``--config`` path.
        str_: Experiment/config name; used to build the default ``--config``
            path ``./exps/<data_name>/<str_>.json``.

    Returns:
        An ``argparse.ArgumentParser`` with all options registered. The caller
        is responsible for calling ``parse_args`` on it.
    """
    parser = argparse.ArgumentParser(description='Reproduce of multiple continual learning algorithms.')
    parser.add_argument('--config', type=str, default='./exps/{}/{}.json'.format(data_name, str_),
                        help='Json file of settings.')

    parser.add_argument('--attn_store', type=int, default=0, help="whether to store attn map")
    parser.add_argument('--change_data_order', type=int, default=0, help="whether to change the order of data ")
    parser.add_argument('--task_num', type=int, default=10, help='how many change times when tuning?')
    # 999 is the "not given" sentinel for the next two options.
    parser.add_argument('--gpu_used', type=int, default=999, help='chose gpu id if not 999')
    parser.add_argument('--prompt_num', type=int, default=999, help='chose prompt num if not 999')
    parser.add_argument('--use_tensorboard', type=int, default=1, help='chose use tensorboard or not?')

    parser.add_argument('--feature_distill_type', type=str, default='allheadallfeature',
                        help='chose feature distill type: allheadallfeature, allheadpromptfeature or '
                             'lastheadpromptfeature (default is allheadallfeature).')
    parser.add_argument('--file_name', type=str, default='', help='change the name of log file, only to see distinctly...')

    # These are read as strings on the command line; the default is the int 0
    # (argparse does not run ``type`` on non-string defaults).
    for lamda_name in ('lamda_for_logit_loss', 'lamda_for_prompt', 'lamda_for_featureformer',
                       'lamda_for_featurelower', 'lamda_for_pool3'):
        parser.add_argument('--' + lamda_name, type=str, default=0,
                            help='value for {} (converted to float by main unless empty)'.format(lamda_name))

    parser.add_argument('--intra_share', type=int, default=1, help='chose to intra share or not? default is share!')
    parser.add_argument('--used_decouple', type=int, default=0, help='decouple the loss or not? default is no!')
    parser.add_argument('--decouple_type', type=str, default="", help=' decouple type ? Block or Part? default is Nothing!')
    parser.add_argument('--Block_which', type=str, default="", help='which block or blocks fetch from......when use decouple type with Block...!')
    parser.add_argument('--Part_which', type=str, default="", help='which part or parts fetch from......when use decouple type with Part...!')
    parser.add_argument('--loss_ratio_print', type=int, default=0, help='draw the loss ratio, which equals to cls loss / feature loss(which has been scaled by a lamda para.....)!')
    parser.add_argument('--ratio_fix_para', type=int, default=0, help='to auto sets the scale of feature loss with the ratio para, after fetch the exper ratio between cls and feature loss.......')

    for lamda_name in ('lamda1', 'lamda2', 'lamda3'):
        parser.add_argument('--' + lamda_name, type=int, default=1,
                            help='integer value for {}'.format(lamda_name))
    parser.add_argument('--lamda0', type=str, default="", help='weight for confidence..')

    for edge_name, edge_default in (('edge1_type', "L2"), ('edge2_type', ""), ('edge3_type', "")):
        parser.add_argument('--' + edge_name, type=str, default=edge_default,
                            help='value for {}'.format(edge_name))

    parser.add_argument('--fc_inittype', type=str, default="type8", help="chose init type of fc......... corresponding to yueque text, type8, type7, type6, type5")
    parser.add_argument('--cal_center', type=int, default=0, help='default is cal center of incremental task')
    parser.add_argument('--train_list', type=str, default='',
                        help='whitespace-separated integers, e.g. "0 1 2" (parsed into a list by main)')
    parser.add_argument('--prompt_store', type=str, default='', help='store the prompt parameter?')
    parser.add_argument('--task_id', type=int, default=0,  help='save the num of the task...?')
    parser.add_argument('--fusion_type', type=str, default='concat', help='pointadd, concat, no_fusion, continual_extract')

    return parser
#===============================================================

if __name__ == '__main__':
    # Choose the dataset and the experiment config. setup_parser combines the
    # two into the default settings path ./exps/<data_name>/<str_>.json.
    #
    # Experiment names noted for data_name = 'cifar':
    #   simplecil, simplecil_attn, adam_vpt_shallow, adam_adapter,
    #   adam_vpt_deep, adam_vpt_deep_with_attn
    # Experiment names noted for data_name = 'imagenet_a':
    #   simplecil_imagenet_a, adam_vpt_shallow_imagenet_a,
    #   adam_adapter_imagenet_a, adam_vpt_deep_imagenet_a
    # NOTE(review): presumably each name has a matching JSON file under
    # ./exps/<data_name>/ - verify before switching.
    data_name, str_ = 'cifar', "adam_vpt_deep"

    main(data_name, str_)

